#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>

#include "rng.h"
#include "api.h"
#include "gmp.h"
#include "kaz_api.h"

/*
 * Compute the SHA-512 digest of msg[0..mlen) and store it in buf.
 *
 * EVP_sha512() produces 64 bytes, so this assumes CRYPTO_BYTES >= 64
 * (NOTE(review): confirm in api.h).
 *
 * The function has no error return. KAZ_DS_SIGNATURE and KAZ_DS_VERIFICATION
 * turn buf straight into the hash value they sign or verify. Continuing after
 * an OpenSSL failure would leave buf uninitialized, so any failure (context
 * allocation or any digest step) aborts the process instead.
 */
void HashMsg(const unsigned char *msg, unsigned int mlen, unsigned char buf[CRYPTO_BYTES])
{
    EVP_MD_CTX *mdctx = EVP_MD_CTX_new();
    unsigned int outlen = 0;   /* digest length written by OpenSSL; not needed afterwards */

    int ok = mdctx != NULL
          && EVP_DigestInit_ex(mdctx, EVP_sha512(), NULL) == 1
          && EVP_DigestUpdate(mdctx, msg, mlen) == 1
          && EVP_DigestFinal_ex(mdctx, buf, &outlen) == 1;

    EVP_MD_CTX_free(mdctx);   /* accepts NULL */

    if (!ok)
        abort();
}

/*
 * Chinese Remainder Theorem reconstruction in mixed-radix (Garner) form.
 *
 * Computes crt with crt = x[i] (mod modulus[i]) for every i in [0, size):
 *   crt = x[0] + u1*m0 + u2*m0*m1 + ...
 *   u_i = (x[i] - crt_{i-1}) * (m0*...*m_{i-1})^{-1}  mod m_i
 *
 * x and modulus are only read. crt must already be initialized by the caller.
 * For size <= 0, crt is set to 0.
 *
 * NOTE(review): the moduli are assumed pairwise coprime. The return value of
 * mpz_invert is not checked, so a non-invertible product gives an undefined
 * result, exactly as before.
 */
void KAZ_DS_CRT(int size, mpz_t *x, mpz_t *modulus, mpz_t crt)
{
    if (size <= 0) {
        mpz_set_ui(crt, 0);
        return;
    }

    mpz_t u, inv, prod;
    mpz_inits(u, inv, prod, NULL);

    mpz_set(crt, x[0]);
    mpz_set(prod, modulus[0]);   /* running product modulus[0] * ... * modulus[i-1] */

    for (int i = 1; i < size; i++) {
        /* The inverse of the running product equals the product of the
           individual inverses mod modulus[i] when all factors are invertible. */
        mpz_invert(inv, prod, modulus[i]);

        mpz_sub(u, x[i], crt);
        mpz_mul(u, u, inv);
        mpz_mod(u, u, modulus[i]);

        mpz_mul(u, u, prod);
        mpz_add(crt, crt, u);

        mpz_mul(prod, prod, modulus[i]);
    }

    mpz_clears(u, inv, prod, NULL);
}

/*
 * Write the magnitude of op into dst[0..len) as big-endian bytes, right-aligned
 * and zero-padded on the left. Aborts if op does not fit, rather than writing
 * outside dst.
 */
static void kaz_ds_keygen_store(unsigned char *dst, size_t len, const mpz_t op)
{
    const size_t nbytes = (mpz_sizeinbase(op, 2) + 7) / 8;

    if (nbytes > len)
        abort();

    memset(dst, 0, len);
    /* For op == 0, mpz_export writes nothing and the field stays all-zero. */
    mpz_export(dst + (len - nbytes), NULL, 1, sizeof(char), 0, 0, op);
}

/*
 * Generate a KAZ-DS key pair.
 *
 * Both output buffers are fully overwritten:
 *   kaz_ds_sign_key   - CRYPTO_SECRETKEYBYTES bytes: SK, big-endian, right-aligned.
 *   kaz_ds_verify_key - CRYPTO_PUBLICKEYBYTES bytes: V2 right-aligned in the
 *                       last KAZ_DS_V2BYTES bytes, and V1 right-aligned in the
 *                       bytes before it.
 *
 * NOTE(review): assumes (k+LG1)/8 and KAZ_DS_SP_RAN are at most 100 (the size
 * of RAN), and that SK, V1 and V2 fit in their fields. If a value does not
 * fit, the process aborts instead of writing out of bounds.
 */
void KAZ_DS_KEYGEN(unsigned char *kaz_ds_verify_key, unsigned char *kaz_ds_sign_key)
{
    mpz_t G1, phiG1, phiphiG1, Q, phiQ, qQ, G1qQ;
    mpz_t a, b, ALPHA, OMEGA1, V1, V2, SK, tmp;

    mpz_inits(G1, phiG1, phiphiG1, Q, phiQ, qQ, G1qQ, NULL);
    mpz_inits(a, b, ALPHA, OMEGA1, V1, V2, SK, tmp, NULL);

    /* System parameters and precomputed values (decimal strings). */
    mpz_set_str(G1, KAZ_DS_SP_G1, 10);
    mpz_set_str(Q, KAZ_DS_SP_Q, 10);
    mpz_set_str(G1qQ, KAZ_DS_SP_G1qQ, 10);
    mpz_set_str(phiG1, KAZ_DS_SP_PHIG1, 10);
    mpz_set_str(phiphiG1, KAZ_DS_SP_PHIPHIG1, 10);
    mpz_set_str(phiQ, KAZ_DS_SP_PHIQ, 10);
    mpz_set_str(qQ, KAZ_DS_SP_qQ, 10);

    const int k = KAZ_DS_SP_K;
    const int LG1 = KAZ_DS_SP_LG1;
    unsigned char RAN[100];

    /* 1) ALPHA = 2 * (random (k+LG1)/8-byte integer), so ALPHA is even. */
    randombytes(RAN, (k + LG1) / 8);
    mpz_import(ALPHA, (k + LG1) / 8, 1, sizeof(char), 0, 0, RAN);
    mpz_mul_ui(ALPHA, ALPHA, 2);

    /* 2) V1 = ALPHA mod G1 */
    mpz_mod(V1, ALPHA, G1);

    /* 3) Random prime a and random OMEGA1 */
    randombytes(RAN, KAZ_DS_SP_RAN);
    mpz_import(a, KAZ_DS_SP_RAN, 1, sizeof(char), 0, 0, RAN);
    mpz_nextprime(a, a);

    randombytes(RAN, KAZ_DS_SP_RAN);
    mpz_import(OMEGA1, KAZ_DS_SP_RAN, 1, sizeof(char), 0, 0, RAN);

    /* RAN seeded secret values; do not leave it on the stack. */
    OPENSSL_cleanse(RAN, sizeof RAN);

    /* 4) b = a^phiphiG1 mod (OMEGA1 * phiG1) */
    mpz_mul(tmp, OMEGA1, phiG1);
    mpz_powm(b, a, phiphiG1, tmp);

    /* 5) V2 = Q * ALPHA^(phiQ * b) mod qQ */
    mpz_mul(tmp, phiQ, b);
    mpz_powm(V2, ALPHA, tmp, qQ);
    mpz_mul(V2, V2, Q);
    mpz_mod(V2, V2, qQ);

    /* 6) SK = ALPHA^(phiQ * b) mod G1qQ; tmp still holds phiQ * b. */
    mpz_powm(SK, ALPHA, tmp, G1qQ);

    /* 7) Serialize directly into the caller's buffers (no temporary copies of SK). */
    kaz_ds_keygen_store(kaz_ds_sign_key, CRYPTO_SECRETKEYBYTES, SK);
    kaz_ds_keygen_store(kaz_ds_verify_key,
                        CRYPTO_PUBLICKEYBYTES - KAZ_DS_V2BYTES, V1);
    kaz_ds_keygen_store(kaz_ds_verify_key + (CRYPTO_PUBLICKEYBYTES - KAZ_DS_V2BYTES),
                        KAZ_DS_V2BYTES, V2);

    mpz_clears(G1, phiG1, phiphiG1, Q, phiQ, qQ, G1qQ, NULL);
    mpz_clears(a, b, ALPHA, OMEGA1, V1, V2, SK, tmp, NULL);
}

/*
 * Sign m[0..mlen) with the secret key sk and produce an attached signature.
 *
 * sk holds SK as KAZ_DS_SKBYTES big-endian bytes.
 * Output layout, with *signlen = KAZ_DS_SBYTES + mlen:
 *   signature[0 .. KAZ_DS_SBYTES)         S, big-endian, right-aligned
 *   signature[KAZ_DS_SBYTES .. +mlen)     copy of m
 *
 * Returns 0 on success. Returns -1 in either of these cases, leaving
 * *signlen untouched:
 *   - mlen does not fit HashMsg's unsigned int length;
 *   - S would not fit in KAZ_DS_SBYTES bytes.
 * NOTE(review): assumes KAZ_DS_SP_RAN <= 100 (the size of RAN).
 */
int KAZ_DS_SIGNATURE(unsigned char *signature, unsigned long long *signlen, const unsigned char *m, unsigned long long mlen, const unsigned char *sk)
{
    /* Refuse rather than sign a hash of a silently truncated message. */
    if (mlen != (unsigned int)mlen)
        return -1;

    mpz_t phiG1, phiphiG1, phiqQ, G1qQ, SK;
    mpz_t hashValue, r, OMEGA2, BETA, S, tmp;

    mpz_inits(phiG1, phiphiG1, phiqQ, G1qQ, SK, NULL);
    mpz_inits(hashValue, r, OMEGA2, BETA, S, tmp, NULL);

    /* System parameters and precomputed values (decimal strings). */
    mpz_set_str(G1qQ, KAZ_DS_SP_G1qQ, 10);
    mpz_set_str(phiG1, KAZ_DS_SP_PHIG1, 10);
    mpz_set_str(phiphiG1, KAZ_DS_SP_PHIPHIG1, 10);
    mpz_set_str(phiqQ, KAZ_DS_SP_PHIqQ, 10);

    /* 1) SK straight from the key bytes; no temporary copy of the secret. */
    mpz_import(SK, KAZ_DS_SKBYTES, 1, sizeof(char), 0, 0, sk);

    /* 2) hashValue = HashMsg(m) read as a big-endian integer. */
    unsigned char buf[CRYPTO_BYTES];
    HashMsg(m, (unsigned int)mlen, buf);
    mpz_import(hashValue, CRYPTO_BYTES, 1, sizeof(char), 0, 0, buf);

    /* 3)-5) Draw fresh (r, OMEGA2) until S has exactly the bit length of
       G1qQ, which Filter 2 of KAZ_DS_VERIFICATION requires. */
    unsigned char RAN[100];
    const size_t lenG1qQ = mpz_sizeinbase(G1qQ, 2);

    do {
        /* Random prime r and random OMEGA2 */
        randombytes(RAN, KAZ_DS_SP_RAN);
        mpz_import(r, KAZ_DS_SP_RAN, 1, sizeof(char), 0, 0, RAN);
        mpz_nextprime(r, r);

        randombytes(RAN, KAZ_DS_SP_RAN);
        mpz_import(OMEGA2, KAZ_DS_SP_RAN, 1, sizeof(char), 0, 0, RAN);

        /* BETA = r^phiphiG1 mod (OMEGA2 * phiG1) */
        mpz_mul(tmp, OMEGA2, phiG1);
        mpz_powm(BETA, r, phiphiG1, tmp);

        /* S = hashValue^(phiqQ * BETA) * SK mod G1qQ */
        mpz_mul(tmp, phiqQ, BETA);
        mpz_powm(S, hashValue, tmp, G1qQ);
        mpz_mul(S, S, SK);
        mpz_mod(S, S, G1qQ);
    } while (mpz_sizeinbase(S, 2) != lenG1qQ);

    /* RAN seeded the per-signature secrets; do not leave it on the stack. */
    OPENSSL_cleanse(RAN, sizeof RAN);

    /* 6) signature = S || m */
    int rc = -1;
    const size_t nbytes = (mpz_sizeinbase(S, 2) + 7) / 8;

    if (nbytes <= (size_t)KAZ_DS_SBYTES) {
        /* Move the message first: memmove tolerates m overlapping signature,
           and m has already been hashed. */
        if (mlen > 0)
            memmove(signature + KAZ_DS_SBYTES, m, (size_t)mlen);

        memset(signature, 0, KAZ_DS_SBYTES);
        mpz_export(signature + (KAZ_DS_SBYTES - nbytes), NULL, 1, sizeof(char), 0, 0, S);

        *signlen = mlen + KAZ_DS_SBYTES;
        rc = 0;
    }

    mpz_clears(phiG1, phiphiG1, phiqQ, G1qQ, SK, NULL);
    mpz_clears(hashValue, r, OMEGA2, BETA, S, tmp, NULL);

    return rc;
}

/*
 * Verify an attached KAZ-DS signature and, on success, recover the message.
 *
 * Layouts:
 *   sm: S (KAZ_DS_SBYTES bytes, big-endian), then the message (smlen - KAZ_DS_SBYTES bytes).
 *   pk: V1 (KAZ_DS_V1BYTES bytes), then V2 (KAZ_DS_V2BYTES bytes), both big-endian.
 *
 * Returns 0 on success, after copying the message into m and setting *mlen.
 * Returns -4 and leaves m and *mlen untouched when any of these holds:
 *   - smlen < KAZ_DS_SBYTES;
 *   - the message is too long for HashMsg's unsigned int length;
 *   - a filter rejects S;
 *   - the final check fails.
 * A filter rejection prints "Filter N..." to stdout.
 */
int KAZ_DS_VERIFICATION(unsigned char *m, unsigned long long *mlen, const unsigned char *sm, unsigned long long smlen, const unsigned char *pk)
{
    /* A shorter sm cannot even hold S; reject before reading it. */
    if (smlen < (unsigned long long)KAZ_DS_SBYTES)
        return -4;

    const unsigned long long len = smlen - KAZ_DS_SBYTES;
    const unsigned char *msg = sm + KAZ_DS_SBYTES;

    /* HashMsg takes an unsigned int length; refuse messages it would truncate. */
    if (len != (unsigned int)len)
        return -4;

    int rc = -4;
    unsigned char buf[CRYPTO_BYTES] = {0};

    mpz_t G0, R, G1, q, Q, phiQ, G1Q, G1qQ, qQ, phiqQ, e;
    mpz_t hashValue, V1, V2, S, tmp, SF1, SF2, yG1Q, yG1, y1, y2;
    mpz_t x[2], modulus[2];

    mpz_inits(G0, R, G1, q, Q, phiQ, G1Q, G1qQ, qQ, phiqQ, e, NULL);
    mpz_inits(hashValue, V1, V2, S, tmp, SF1, SF2, yG1Q, yG1, y1, y2, NULL);
    mpz_inits(x[0], x[1], modulus[0], modulus[1], NULL);

    /* System parameters and precomputed values (decimal strings). */
    mpz_set_str(G0, KAZ_DS_SP_G0, 10);
    mpz_set_str(R, KAZ_DS_SP_R, 10);
    mpz_set_str(G1, KAZ_DS_SP_G1, 10);
    mpz_set_str(q, KAZ_DS_SP_q, 10);
    mpz_set_str(Q, KAZ_DS_SP_Q, 10);
    mpz_set_str(G1Q, KAZ_DS_SP_G1Q, 10);
    mpz_set_str(G1qQ, KAZ_DS_SP_G1qQ, 10);
    mpz_set_str(phiQ, KAZ_DS_SP_PHIQ, 10);
    mpz_set_str(qQ, KAZ_DS_SP_qQ, 10);
    mpz_set_str(phiqQ, KAZ_DS_SP_PHIqQ, 10);

    /* 1) Public key (V1, V2) */
    mpz_import(V1, KAZ_DS_V1BYTES, 1, sizeof(char), 0, 0, pk);
    mpz_import(V2, KAZ_DS_V2BYTES, 1, sizeof(char), 0, 0, pk + KAZ_DS_V1BYTES);

    /* 2) Signature value S */
    mpz_import(S, KAZ_DS_SBYTES, 1, sizeof(char), 0, 0, sm);

    /* 3) hashValue = HashMsg(message) read as a big-endian integer */
    HashMsg(msg, (unsigned int)len, buf);
    mpz_import(hashValue, CRYPTO_BYTES, 1, sizeof(char), 0, 0, buf);

    /* 4) Filtering procedures */

    /* FILTER 1: S must already be reduced modulo G1qQ. */
    mpz_mod(tmp, S, G1qQ);
    if (mpz_cmp(tmp, S) != 0) {
        printf("Filter 1...\n");
        goto out;
    }

    /* FILTER 2: S must have exactly the bit length of G1qQ. */
    if (mpz_sizeinbase(S, 2) != mpz_sizeinbase(G1qQ, 2)) {
        printf("Filter 2...\n");
        goto out;
    }

    /* Shared CRT input for Filters 3 and 4. KAZ_DS_CRT only reads x. */
    mpz_divexact(x[0], V2, Q);

    /* FILTER 3: reject S == CRT(V2/Q mod q, yG1Q mod G1Q), where
       yG1Q = V1^phiQ * hashValue^phiqQ mod G1Q.
       After Filter 1, S equals S mod G1qQ, so S is compared directly. */
    mpz_powm(yG1Q, V1, phiQ, G1Q);
    mpz_powm(tmp, hashValue, phiqQ, G1Q);
    mpz_mul(yG1Q, yG1Q, tmp);
    mpz_mod(yG1Q, yG1Q, G1Q);

    mpz_set(x[1], yG1Q);
    mpz_set(modulus[0], q);
    mpz_set(modulus[1], G1Q);
    KAZ_DS_CRT(2, x, modulus, SF1);

    if (mpz_cmp(S, SF1) == 0) {
        printf("Filter 3...\n");
        goto out;
    }

    /* FILTER 4: same shape as Filter 3, using
       yG1 = V1^phiQ * hashValue^phiqQ mod G1 and the moduli
       qQ / gcd(Q, G1) and G1. */
    mpz_powm(yG1, V1, phiQ, G1);
    mpz_powm(tmp, hashValue, phiqQ, G1);
    mpz_mul(yG1, yG1, tmp);
    mpz_mod(yG1, yG1, G1);

    mpz_set(x[1], yG1);
    mpz_gcd(e, Q, G1);
    mpz_divexact(modulus[0], qQ, e);
    mpz_set(modulus[1], G1);
    KAZ_DS_CRT(2, x, modulus, SF2);

    if (mpz_cmp(S, SF2) == 0) {
        printf("Filter 4...\n");
        goto out;
    }

    /* FILTER 5: Q * S mod qQ must equal V2. */
    mpz_mul(tmp, Q, S);
    mpz_mod(tmp, tmp, qQ);
    if (mpz_cmp(tmp, V2) != 0) {
        printf("Filter 5...\n");
        goto out;
    }

    /* 5) Verifying procedure: accept iff R^S == R^yG1 (mod G0).
       yG1 is the same exponent computed in Filter 4. */
    mpz_powm(y1, R, S, G0);
    mpz_powm(y2, R, yG1, G0);
    if (mpz_cmp(y1, y2) != 0)
        goto out;

    /* memmove tolerates m overlapping sm. */
    if (len > 0)
        memmove(m, msg, (size_t)len);
    *mlen = len;
    rc = 0;

out:
    mpz_clears(G0, R, G1, q, Q, phiQ, G1Q, G1qQ, qQ, phiqQ, e, NULL);
    mpz_clears(hashValue, V1, V2, S, tmp, SF1, SF2, yG1Q, yG1, y1, y2, NULL);
    mpz_clears(x[0], x[1], modulus[0], modulus[1], NULL);

    return rc;
}